"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import importlib.metadata

from torch import Tensor

def _parse_version(version):
    """Return the leading numeric release components of ``version`` as ints.

    ``"0.16.0+cu118"`` -> ``(0, 16, 0)``; ``"0.17.0a0"`` -> ``(0, 17, 0)``.
    Returns an empty tuple when ``version`` does not start with a digit, which
    compares below every supported release and so hits the error branch.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group().split("."))


# Compare versions numerically, not as strings: lexicographically "0.9" sorts
# after "0.17", which would send ancient releases down the newest-API branch.
_TORCHVISION_VERSION = _parse_version(importlib.metadata.version("torchvision"))

if (0, 15, 2) <= _TORCHVISION_VERSION < (0, 16):
    import torchvision

    # Silence torchvision's warning about the beta v2 transforms / datapoints API.
    torchvision.disable_beta_transforms_warning()

    # 0.15.2 ships these under older names; alias them to the newer names so
    # the rest of the module can use a single spelling.
    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes

    # Keyword names BoundingBoxes expects for (format, image size) in this release.
    _boxes_keys = ["format", "spatial_size"]

elif (0, 16) <= _TORCHVISION_VERSION < (0, 17):
    import torchvision

    torchvision.disable_beta_transforms_warning()

    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    # The image-size keyword was renamed to ``canvas_size`` from 0.16 on.
    _boxes_keys = ["format", "canvas_size"]

elif _TORCHVISION_VERSION >= (0, 17):
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]

else:
    raise RuntimeError("Please make sure torchvision version >= 0.15.2")


def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor:
    """Wrap a plain tensor in the torchvision tv_tensor type that matches ``key``.

    Args:
        tensor (Tensor): input tensor to wrap.
        key (str): kind of target held by ``tensor``; must be ``"boxes"`` or
            ``"masks"``.
        box_format (str): box layout name, upper-cased and looked up as an
            attribute of ``BoundingBoxFormat`` (e.g. ``"xyxy"``). Only used
            when ``key == "boxes"``.
        spatial_size: image size passed to ``BoundingBoxes`` under the second
            keyword in ``_boxes_keys`` (``spatial_size`` on torchvision 0.15.2,
            ``canvas_size`` on later releases). Only used when ``key == "boxes"``.

    Returns:
        A ``BoundingBoxes`` instance for ``"boxes"``, or a ``Mask`` for ``"masks"``.

    Raises:
        ValueError: if ``key`` is neither ``"boxes"`` nor ``"masks"``.
        AttributeError: if ``box_format`` names no ``BoundingBoxFormat`` member.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, after which an unknown key would silently return None.
    if key not in ("boxes", "masks"):
        raise ValueError("Only support 'boxes' and 'masks'")

    if key == "boxes":
        fmt = getattr(BoundingBoxFormat, box_format.upper())
        # _boxes_keys carries the version-specific kwarg names for (format, size).
        box_kwargs = dict(zip(_boxes_keys, [fmt, spatial_size]))
        return BoundingBoxes(tensor, **box_kwargs)

    return Mask(tensor)
